import torch
import torch.nn as nn


class LinearAgg(torch.nn.Module):
    """Linear score plus (optionally weighted) neighbour aggregation.

    Each row of ``x`` is projected to a scalar score ``s`` by a bias-free
    linear layer. The output is ``sigmoid(s + w * (Adj @ s))``. Here ``w`` is
    a learnable scalar (initialised to 1.0) when ``weighted`` is true; when
    ``weighted`` is false it is omitted (i.e. ``w == 1``).

    Args:
        d: number of input features per row.
        weighted: whether to learn a scalar weight on the aggregated term.
    """

    def __init__(self, d, weighted=False):
        super().__init__()
        self.linear = nn.Linear(d, 1, bias=False)
        if weighted:
            self.weight = nn.Parameter(torch.tensor(1.0), requires_grad=True)
        else:
            self.weight = None

    def forward(self, Adj, x):
        """Return ``(sigmoid(s + w * Adj @ s), linear.weight)``.

        ``x`` is cast to float32 before projection. Presumably ``x`` is
        (n, d) and ``Adj`` is (n, n), giving an (n, 1) output — confirm
        against callers.
        """
        scores = self.linear(x.float())
        aggregated = Adj @ scores
        if self.weight is not None:
            aggregated = self.weight * aggregated
        return torch.sigmoid(scores + aggregated), self.linear.weight


class OptimalMP(torch.nn.Module):
    """Message passing over linear scores clipped by a learnable threshold.

    Each row of ``x`` is projected to a scalar score ``s`` by a bias-free
    linear layer. The output is ``sigmoid(s + Adj @ clip(s))``, where
    ``clip`` bounds scores to ``[-thres, thres]``. ``thres`` is a learnable
    scalar initialised to 0.2.

    Args:
        d: number of input features per row.
    """

    def __init__(self, d):
        super().__init__()
        self.linear = nn.Linear(d, 1, bias=False)
        self.thres = nn.Parameter(torch.tensor(0.2), requires_grad=True)

    def forward(self, Adj, x):
        """Return ``(sigmoid(s + Adj @ clip(s)), linear.weight, thres)``.

        Presumably ``x`` is (n, d) and ``Adj`` is (n, n) — confirm against
        callers.
        """
        score = self.linear(x.float())
        bound = self.thres
        # Upper bound first, then lower bound on the already-capped values;
        # gradients reach ``thres`` at every clipped position.
        capped = torch.where(score > bound, bound, score)
        capped = torch.where(capped < -bound, -bound, capped)
        return torch.sigmoid(score + Adj @ capped), self.linear.weight, self.thres


class OptimalMP_Laplacian(torch.nn.Module):
    """Message passing with learnable clipping before and after projection.

    Raw features are clipped to ``[-thres_1, thres_1]``. Each row is then
    projected to a scalar score ``s`` by a bias-free linear layer. The output
    is ``sigmoid(s + Adj @ clip_2(s))``, where ``clip_2`` bounds scores to
    ``[-thres_2, thres_2]``. Both thresholds are learnable scalars initialised
    to 2.0.

    Args:
        d: number of input features per row.
        weighted: accepted but not used by this class.
    """

    def __init__(self, d, weighted=False):
        super().__init__()
        self.linear = nn.Linear(d, 1, bias=False)
        self.thres_1 = nn.Parameter(torch.tensor(2.0), requires_grad=True)
        self.thres_2 = nn.Parameter(torch.tensor(2.0), requires_grad=True)

    @staticmethod
    def _clip(x, thres):
        """Return ``x`` clipped to ``[-thres, thres]`` without mutating ``x``.

        The upper bound is applied first, then the lower bound on the capped
        values. The result therefore matches sequential masked assignment for
        any sign of ``thres``. Gradients flow to ``thres`` at clipped
        positions.
        """
        x = torch.where(x > thres, thres, x)
        return torch.where(x < -thres, -thres, x)

    def forward(self, Adj, x):
        """Return ``(outputs, linear.weight, thres_1, thres_2)``.

        Args:
            Adj: matrix left-multiplied onto the clipped scores (presumably
                (n, n) — confirm against callers).
            x: input features, cast to float32 (presumably (n, d)). Not
                modified.
        """
        # Clip out of place. ``x.float()`` returns the caller's own tensor when
        # it is already float32. Masked in-place assignment would therefore
        # overwrite the caller's features and graft them onto this step's
        # autograd graph, which breaks reuse of the same input on the next step.
        clipped_in = self._clip(x.float(), self.thres_1)
        scores = self.linear(clipped_in)
        out = scores + Adj @ self._clip(scores, self.thres_2)
        return torch.sigmoid(out), self.linear.weight, self.thres_1, self.thres_2


class OptimalMP_Laplacian_Psi(torch.nn.Module):
    """Message passing with learnable clipping of the raw input features.

    Raw features are clipped to ``[-thres_1, thres_1]``. Each row is then
    projected to a scalar score ``s`` by a bias-free linear layer. The output
    is ``sigmoid(s + Adj @ s)``. ``thres_1`` is a learnable scalar initialised
    to 2.0.

    Args:
        d: number of input features per row.
    """

    def __init__(self, d):
        super().__init__()
        self.linear = nn.Linear(d, 1, bias=False)
        self.thres_1 = nn.Parameter(torch.tensor(2.0), requires_grad=True)

    @staticmethod
    def _clip(x, thres):
        """Return ``x`` clipped to ``[-thres, thres]`` without mutating ``x``.

        The upper bound is applied first, then the lower bound on the capped
        values, matching sequential masked assignment. Gradients flow to
        ``thres`` at clipped positions.
        """
        x = torch.where(x > thres, thres, x)
        return torch.where(x < -thres, -thres, x)

    def forward(self, Adj, x):
        """Return ``(sigmoid(s + Adj @ s), linear.weight, thres_1)``.

        Args:
            Adj: matrix left-multiplied onto the scores (presumably (n, n) —
                confirm against callers).
            x: input features, cast to float32 (presumably (n, d)). Not
                modified.
        """
        # Clip out of place. ``x.float()`` is a no-op returning the caller's
        # tensor when it is already float32. In-place masked assignment would
        # mutate the caller's data and attach it to this step's autograd graph.
        scores = self.linear(self._clip(x.float(), self.thres_1))
        return torch.sigmoid(scores + Adj @ scores), self.linear.weight, self.thres_1


class OptimalMP_Laplacian_Phi(torch.nn.Module):
    """Message passing over projected scores clipped by a learnable threshold.

    Each row of ``x`` is projected to a scalar score ``s`` by a bias-free
    linear layer. The output is ``sigmoid(s + Adj @ clip(s))``, where
    ``clip`` bounds scores to ``[-thres_2, thres_2]``. ``thres_2`` is a
    learnable scalar initialised to 2.0.

    Args:
        d: number of input features per row.
    """

    def __init__(self, d):
        super().__init__()
        self.linear = nn.Linear(d, 1, bias=False)
        self.thres_2 = nn.Parameter(torch.tensor(2.0), requires_grad=True)

    def forward(self, Adj, x):
        """Return ``(sigmoid(s + Adj @ clip(s)), linear.weight, thres_2)``.

        Presumably ``x`` is (n, d) and ``Adj`` is (n, n) — confirm against
        callers.
        """
        raw = self.linear(x.float())
        limit = self.thres_2
        # Cap from above, then from below on the capped values; gradients
        # reach ``thres_2`` wherever a score was clipped.
        bounded = torch.where(raw > limit, limit, raw)
        bounded = torch.where(bounded < -limit, -limit, bounded)
        return torch.sigmoid(raw + Adj @ bounded), self.linear.weight, self.thres_2
